//Execute.scala
package mycore

import chisel3._
import chisel3.util._

/** Execute (EX) stage of the pipeline.
  *
  * Latches the ID/EX payload while `ENABLE` is high, resolves the two register
  * operands by forwarding from the MEM and WB stages, and drives the ALU, the
  * branch controller and the CSR unit. Redirect requests (branch, trap, mret)
  * are reported on `nextPC`; `cancel` is raised whenever any of them fires for a
  * valid instruction (presumably used upstream to squash younger instructions).
  * All results and a CSR snapshot are forwarded to the next stage on `exmem`.
  */
class Execute extends Module {
  val io = IO(new Bundle {
    val ENABLE   = Input(Bool())
    val idex     = Flipped(new IDEX)
    val exmem    = new EXMEM
    val cancel   = Output(Bool())
    val nextPC   = new NEXTPC
    val redesMem = Flipped(new WBDATA) // write-back info of the instruction in MEM
    val redesWb  = Flipped(new WBDATA) // write-back info of the instruction in WB
    val TimeOver = Input(Bool())
  })

  // ---------------- Pipeline register: inputs of this stage ----------------
  // Captures the ID/EX payload only while ENABLE is high; resets to all zeros.
  val stage   = RegEnable(io.idex, 0.U.asTypeOf(new IDEX), io.ENABLE)
  val valid   = stage.Valid
  val pc      = stage.PC
  val isa     = stage.isa
  val imm     = stage.imm
  // Source register indices are taken straight from the rs1/rs2 instruction fields.
  // NOTE(review): these fields are decoded for every instruction, even formats that
  // have no rs1/rs2 (e.g. U/J-type, or CSR*I where bits 19..15 are a zimm); confirm
  // that forwarding into src1/src2 is harmless for those.
  val rs1Addr = stage.Inst(19, 15)
  val rs2Addr = stage.Inst(24, 20)

  // ---------------- Operand forwarding ----------------
  /** True when `port` writes register `addr` (register x0 is never forwarded). */
  def hit(port: WBDATA, addr: UInt): Bool =
    port.wen && (addr === port.regdes) && (addr =/= 0.U)

  // The MEM-stage result is younger than the WB-stage result, so it is listed first
  // (MuxCase gives the earliest matching entry priority). If neither hits, the operand
  // latched from the decode stage is used.
  val src1 = MuxCase(stage.src1, Seq(
    hit(io.redesMem, rs1Addr) -> io.redesMem.data,
    hit(io.redesWb,  rs1Addr) -> io.redesWb.data
  ))
  val src2 = MuxCase(stage.src2, Seq(
    hit(io.redesMem, rs2Addr) -> io.redesMem.data,
    hit(io.redesWb,  rs2Addr) -> io.redesWb.data
  ))

  // ---------------- Functional units ----------------
  val alu = Module(new ArithmeticLogicalUnit)
  alu.io.isa  := isa
  alu.io.imm  := imm
  alu.io.src1 := src1
  alu.io.src2 := src2

  // The branch controller also covers the jump-and-link instructions.
  val bc = Module(new BranchController)
  bc.io.isa  := isa
  bc.io.imm  := imm
  bc.io.src1 := src1
  bc.io.src2 := src2
  bc.io.pc   := pc
  val branch = bc.io.branch
  val target = bc.io.target

  // JAL/JALR return address and AUIPC result. Each 1-bit flag is widened with
  // SignExt to 64 bits and used as a mask (presumably all ones when the flag is set,
  // all zeros otherwise), so the value is zero for any other instruction.
  val linkMask  = SignExt((isa.JAL | isa.JALR).asUInt, 64)
  val link      = linkMask & (pc + 4.U)
  val auipcMask = SignExt(isa.AUIPC.asUInt, 64)
  // NOTE(review): `imm` is hardware (it drives alu.io.imm) and `.U` is not a standard
  // Chisel method on data types; confirm the project provides this conversion,
  // otherwise the intended expression is probably `pc + imm`.
  val auipc     = auipcMask & (pc + imm.U)

  val csru = Module(new ControlStatusRegisterUnit)
  csru.io.ENABLE   := io.ENABLE
  csru.io.pc       := pc
  // A bubble is presented to the CSR unit as 0x13 (`addi x0, x0, 0`, the RISC-V NOP).
  csru.io.inst     := Mux(valid, stage.Inst, 0x13.U)
  csru.io.src      := src1
  csru.io.TimeOver := io.TimeOver
  val csrState        = csru.io.CSRState
  val timerInterrupt  = csru.io.TimerInterrupt
  val environmentCall = csru.io.EnvironmentCall

  // ---------------- Control-flow redirects ----------------
  // Every redirect request is qualified by the valid bit of the instruction in EX.
  val takeTrap   = valid && (timerInterrupt || environmentCall)
  val takeMret   = valid && isa.MRET
  val takeBranch = valid && branch

  io.nextPC.trap   := takeTrap
  io.nextPC.mtvec  := csrState.mtvec
  io.nextPC.mret   := takeMret
  io.nextPC.mepc   := csrState.mepc
  io.nextPC.branch := takeBranch
  io.nextPC.target := target

  io.cancel := takeBranch || takeTrap || takeMret

  // ---------------- Outputs to the next pipeline stage ----------------
  // An instruction hit by a timer interrupt is marked invalid; any other instruction
  // keeps its own valid bit.
  io.exmem.Valid           := valid && !timerInterrupt
  io.exmem.Inst            := stage.Inst
  io.exmem.PC              := pc
  io.exmem.isa             := isa
  io.exmem.imm             := imm
  io.exmem.wen             := stage.wen
  io.exmem.regdes          := stage.regdes
  // The forwarded operands (not the latched ones) travel downstream.
  io.exmem.src1            := src1
  io.exmem.src2            := src2
  io.exmem.aluresult       := alu.io.result
  io.exmem.branch          := branch
  io.exmem.target          := target
  io.exmem.link            := link
  io.exmem.auipc           := auipc
  io.exmem.csrData         := csru.io.dataout
  io.exmem.TimerInterrupt  := timerInterrupt
  io.exmem.EnvironmentCall := environmentCall

  // CSR snapshot handed to later stages.
  io.exmem.csr.mstatus  := csrState.mstatus
  io.exmem.csr.mcause   := csrState.mcause
  io.exmem.csr.mepc     := csrState.mepc
  io.exmem.csr.mie      := csrState.mie
  io.exmem.csr.mscratch := csrState.mscratch
  io.exmem.csr.medeleg  := csrState.medeleg
  io.exmem.csr.mtvec    := csrState.mtvec
  // Not sourced from the CSR unit in this stage; tied to zero.
  io.exmem.csr.mhartid  := 0.U
  io.exmem.csr.mcycle   := 0.U
  io.exmem.csr.mip      := 0.U
}